{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torchvision import datasets, transforms\n",
    "import torch.utils.data\n",
    "import numpy as np\n",
    "from sklearn.cluster import KMeans\n",
    "import math\n",
    "from copy import deepcopy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "def k_means_cpu(weight, n_clusters, init='k-means++', max_iter=50,\n",
    "                device=None, random_state=None):\n",
    "    \"\"\"Cluster the scalar entries of ``weight`` with scikit-learn's KMeans.\n",
    "\n",
    "    Args:\n",
    "        weight: numpy array of any shape; each element is one 1-D sample.\n",
    "        n_clusters: requested cluster count, clamped to ``weight.size``.\n",
    "        init: forwarded to ``KMeans(init=...)``.\n",
    "        max_iter: forwarded to ``KMeans(max_iter=...)``.\n",
    "        device: torch device for the returned tensors. ``None`` selects CUDA\n",
    "            when it is available and falls back to the CPU otherwise.\n",
    "        random_state: forwarded to ``KMeans(random_state=...)``; pass an int\n",
    "            for reproducible clusters.\n",
    "\n",
    "    Returns:\n",
    "        ``(centroids, labels)``: ``centroids`` is a ``(1, k)`` tensor of cluster\n",
    "        centres and ``labels`` is an int32 tensor shaped like ``weight`` that\n",
    "        holds the cluster index of every element.\n",
    "    \"\"\"\n",
    "    if device is None:\n",
    "        # Keep the historical CUDA placement, but do not crash on CPU-only hosts.\n",
    "        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "    org_shape = weight.shape\n",
    "    samples = weight.reshape(-1, 1)  # KMeans expects (n_samples, n_features)\n",
    "    # KMeans cannot build more clusters than there are samples.\n",
    "    n_clusters = min(n_clusters, samples.size)\n",
    "\n",
    "    k_means = KMeans(n_clusters=n_clusters, init=init, n_init=1,\n",
    "                     max_iter=max_iter, random_state=random_state)\n",
    "    k_means.fit(samples)\n",
    "\n",
    "    centroids = torch.from_numpy(k_means.cluster_centers_).to(device).view(1, -1)\n",
    "    labels = torch.from_numpy(k_means.labels_.reshape(org_shape)).int().to(device)\n",
    "    return centroids, labels\n",
    "\n",
    "\n",
    "def reconstruct_weight_from_k_means_result(centroids, labels):\n",
    "    \"\"\"Rebuild a dense tensor by replacing every label with its centroid value.\n",
    "\n",
    "    Args:\n",
    "        centroids: tensor with one value per cluster (any shape; it is flattened).\n",
    "        labels: integer tensor of cluster indices into the flattened centroids.\n",
    "\n",
    "    Returns:\n",
    "        float32 tensor shaped like ``labels`` and placed on ``labels``' device.\n",
    "    \"\"\"\n",
    "    # Flatten rather than squeeze: squeezing a single centroid yields a 0-d\n",
    "    # value that cannot be iterated or indexed per cluster.\n",
    "    lookup = centroids.reshape(-1).to(device=labels.device, dtype=torch.float32)\n",
    "    # Index lookup replaces the per-cluster Python loop; indices are cast to int64.\n",
    "    return lookup[labels.long()]\n",
    "\n",
    "\n",
    "class QuantLinear(nn.Linear):\n",
    "    \"\"\"``nn.Linear`` whose weight (and optionally bias) can be k-means quantized.\n",
    "\n",
    "    ``kmeans_quant`` replaces every parameter value by the centroid of its\n",
    "    cluster and remembers the per-element cluster labels. ``kmeans_update``\n",
    "    re-ties the parameters afterwards by setting every cluster's members to\n",
    "    their current mean.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_features, out_features, bias=True):\n",
    "        super(QuantLinear, self).__init__(in_features, out_features, bias)\n",
    "        self.weight_labels = None  # cluster index of every weight element\n",
    "        self.bias_labels = None    # cluster index of every bias element\n",
    "        self.num_cent = None       # requested number of centroids (2 ** bits)\n",
    "        self.quant_flag = False    # True once kmeans_quant has run\n",
    "        self.quant_bias = False    # True when the bias was quantized as well\n",
    "\n",
    "    def _quantize_param(self, param):\n",
    "        \"\"\"Snap ``param`` to its k-means centroids in place; return its labels.\"\"\"\n",
    "        values = param.data\n",
    "        centroids, labels = k_means_cpu(values.cpu().numpy(), self.num_cent)\n",
    "        quantized = reconstruct_weight_from_k_means_result(centroids, labels)\n",
    "        # The helpers pick their own device; keep results next to the parameter.\n",
    "        param.data = quantized.to(device=values.device, dtype=values.dtype)\n",
    "        return labels.to(values.device)\n",
    "\n",
    "    def _tie_param(self, param, labels):\n",
    "        \"\"\"Set every element of ``param`` to the mean of its cluster.\"\"\"\n",
    "        values = param.data\n",
    "        labels = labels.to(values.device)\n",
    "        tied = torch.zeros_like(values)\n",
    "        for idx in range(self.num_cent):\n",
    "            members = labels == idx\n",
    "            # Small tensors get fewer clusters than num_cent; averaging an\n",
    "            # empty cluster is 0/0 and would turn the whole tensor into NaN.\n",
    "            if members.any():\n",
    "                tied[members] = values[members].mean()\n",
    "        param.data = tied\n",
    "\n",
    "    def kmeans_quant(self, bias=False, quantize_bit=4):\n",
    "        \"\"\"Quantize the weight (and the bias when ``bias`` is true) to ``2 ** quantize_bit`` values.\"\"\"\n",
    "        self.num_cent = 2 ** quantize_bit\n",
    "        self.weight_labels = self._quantize_param(self.weight)\n",
    "\n",
    "        # A layer built with bias=False has nothing to quantize.\n",
    "        quantize_bias = bool(bias) and self.bias is not None\n",
    "        if quantize_bias:\n",
    "            self.bias_labels = self._quantize_param(self.bias)\n",
    "\n",
    "        self.quant_flag = True\n",
    "        self.quant_bias = quantize_bias\n",
    "\n",
    "    def kmeans_update(self):\n",
    "        \"\"\"Re-tie quantized parameters to their cluster means; no-op before ``kmeans_quant``.\"\"\"\n",
    "        if not self.quant_flag:\n",
    "            return\n",
    "\n",
    "        self._tie_param(self.weight, self.weight_labels)\n",
    "        if self.quant_bias:\n",
    "            self._tie_param(self.bias, self.bias_labels)\n",
    "\n",
    "            \n",
    "class QuantConv2d(nn.Conv2d):\n",
    "    \"\"\"``nn.Conv2d`` whose weight (and optionally bias) can be k-means quantized.\n",
    "\n",
    "    ``kmeans_quant`` replaces every parameter value by the centroid of its\n",
    "    cluster and remembers the per-element cluster labels. ``kmeans_update``\n",
    "    re-ties the parameters afterwards by setting every cluster's members to\n",
    "    their current mean.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_channels, out_channels, kernel_size, stride=1,\n",
    "                 padding=0, dilation=1, groups=1, bias=True):\n",
    "        super(QuantConv2d, self).__init__(in_channels, out_channels,\n",
    "            kernel_size, stride, padding, dilation, groups, bias)\n",
    "        self.weight_labels = None  # cluster index of every weight element\n",
    "        self.bias_labels = None    # cluster index of every bias element\n",
    "        self.num_cent = None       # requested number of centroids (2 ** bits)\n",
    "        self.quant_flag = False    # True once kmeans_quant has run\n",
    "        self.quant_bias = False    # True when the bias was quantized as well\n",
    "\n",
    "    def _quantize_param(self, param):\n",
    "        \"\"\"Snap ``param`` to its k-means centroids in place; return its labels.\"\"\"\n",
    "        values = param.data\n",
    "        centroids, labels = k_means_cpu(values.cpu().numpy(), self.num_cent)\n",
    "        quantized = reconstruct_weight_from_k_means_result(centroids, labels)\n",
    "        # The helpers pick their own device; keep results next to the parameter.\n",
    "        param.data = quantized.to(device=values.device, dtype=values.dtype)\n",
    "        return labels.to(values.device)\n",
    "\n",
    "    def _tie_param(self, param, labels):\n",
    "        \"\"\"Set every element of ``param`` to the mean of its cluster.\"\"\"\n",
    "        values = param.data\n",
    "        labels = labels.to(values.device)\n",
    "        tied = torch.zeros_like(values)\n",
    "        for idx in range(self.num_cent):\n",
    "            members = labels == idx\n",
    "            # Small tensors get fewer clusters than num_cent; averaging an\n",
    "            # empty cluster is 0/0 and would turn the whole tensor into NaN.\n",
    "            if members.any():\n",
    "                tied[members] = values[members].mean()\n",
    "        param.data = tied\n",
    "\n",
    "    def kmeans_quant(self, bias=False, quantize_bit=4):\n",
    "        \"\"\"Quantize the weight (and the bias when ``bias`` is true) to ``2 ** quantize_bit`` values.\"\"\"\n",
    "        self.num_cent = 2 ** quantize_bit\n",
    "        self.weight_labels = self._quantize_param(self.weight)\n",
    "\n",
    "        # A layer built with bias=False has nothing to quantize.\n",
    "        quantize_bias = bool(bias) and self.bias is not None\n",
    "        if quantize_bias:\n",
    "            self.bias_labels = self._quantize_param(self.bias)\n",
    "\n",
    "        self.quant_flag = True\n",
    "        self.quant_bias = quantize_bias\n",
    "\n",
    "    def kmeans_update(self):\n",
    "        \"\"\"Re-tie quantized parameters to their cluster means; no-op before ``kmeans_quant``.\"\"\"\n",
    "        if not self.quant_flag:\n",
    "            return\n",
    "\n",
    "        self._tie_param(self.weight, self.weight_labels)\n",
    "        if self.quant_bias:\n",
    "            self._tie_param(self.bias, self.bias_labels)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ConvNet(nn.Module):\n",
    "    \"\"\"Three-conv, one-linear classifier built from quantizable layers.\n",
    "\n",
    "    The 3x3 convolutions use padding 1 and the two 2x2 max-pools each halve\n",
    "    the spatial size, so the ``7*7*64`` input of the linear layer corresponds\n",
    "    to single-channel 28x28 inputs. The network returns 10 raw logits.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super(ConvNet, self).__init__()\n",
    "\n",
    "        self.conv1 = QuantConv2d(1, 32, kernel_size=3, padding=1, stride=1)\n",
    "        self.relu1 = nn.ReLU(inplace=True)\n",
    "        self.maxpool1 = nn.MaxPool2d(2)\n",
    "\n",
    "        self.conv2 = QuantConv2d(32, 64, kernel_size=3, padding=1, stride=1)\n",
    "        self.relu2 = nn.ReLU(inplace=True)\n",
    "        self.maxpool2 = nn.MaxPool2d(2)\n",
    "\n",
    "        self.conv3 = QuantConv2d(64, 64, kernel_size=3, padding=1, stride=1)\n",
    "        self.relu3 = nn.ReLU(inplace=True)\n",
    "\n",
    "        self.linear1 = QuantLinear(7*7*64, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the class logits for a batch of images ``x``.\"\"\"\n",
    "        out = self.maxpool1(self.relu1(self.conv1(x)))\n",
    "        out = self.maxpool2(self.relu2(self.conv2(out)))\n",
    "        out = self.relu3(self.conv3(out))\n",
    "        out = out.view(out.size(0), -1)  # flatten everything but the batch axis\n",
    "        out = self.linear1(out)\n",
    "        return out\n",
    "\n",
    "    def _quant_layers(self):\n",
    "        \"\"\"Yield every quantizable sub-layer in registration order.\"\"\"\n",
    "        # Discovering the layers keeps quantization in sync with __init__\n",
    "        # instead of relying on a hand-maintained list of attribute names.\n",
    "        for module in self.modules():\n",
    "            if isinstance(module, (QuantConv2d, QuantLinear)):\n",
    "                yield module\n",
    "\n",
    "    def kmeans_quant(self, bias=False, quantize_bit=4):\n",
    "        \"\"\"Quantize every quantizable layer to ``2 ** quantize_bit`` values.\"\"\"\n",
    "        for layer in self._quant_layers():\n",
    "            layer.kmeans_quant(bias, quantize_bit)\n",
    "\n",
    "    def kmeans_update(self):\n",
    "        \"\"\"Re-tie the parameters of every quantized layer to their cluster means.\"\"\"\n",
    "        for layer in self._quant_layers():\n",
    "            layer.kmeans_update()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, device, train_loader, optimizer, epoch):\n",
    "    \"\"\"Run one training epoch with cross-entropy loss and print a progress bar.\n",
    "\n",
    "    Args:\n",
    "        model: network to optimize; it is switched to training mode.\n",
    "        device: device every batch is moved to before the forward pass.\n",
    "        train_loader: iterable of ``(data, target)`` batches exposing ``.dataset``.\n",
    "        optimizer: optimizer stepping ``model``'s parameters.\n",
    "        epoch: epoch number, used only in the progress line.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    n_batches = len(train_loader)\n",
    "    n_samples = len(train_loader.dataset)\n",
    "    seen = 0\n",
    "    for step, (inputs, labels) in enumerate(train_loader):\n",
    "        inputs = inputs.to(device)\n",
    "        labels = labels.to(device)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        batch_loss = F.cross_entropy(model(inputs), labels)\n",
    "        batch_loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        seen += len(inputs)\n",
    "        # The bar is 50 dashes wide, so each dash stands for 2 percent.\n",
    "        filled = math.ceil(step / n_batches * 50)\n",
    "        bar = '-' * filled + '>'\n",
    "        print(\"\\rTrain epoch %d: %d/%d, [%-51s] %d%%\" %\n",
    "              (epoch, seen, n_samples, bar, filled * 2), end='')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(model, device, test_loader):\n",
    "    \"\"\"Evaluate ``model`` on ``test_loader`` and print a one-line summary.\n",
    "\n",
    "    Args:\n",
    "        model: network to evaluate; it is switched to evaluation mode.\n",
    "        device: device every batch is moved to before the forward pass.\n",
    "        test_loader: iterable of ``(data, target)`` batches exposing ``.dataset``.\n",
    "\n",
    "    Returns:\n",
    "        ``(average_loss, accuracy)`` where the loss is the per-sample mean\n",
    "        cross-entropy and the accuracy is a fraction in ``[0, 1]``.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    n_total = len(test_loader.dataset)\n",
    "    loss_sum = 0\n",
    "    n_correct = 0\n",
    "    with torch.no_grad():\n",
    "        for inputs, labels in test_loader:\n",
    "            inputs = inputs.to(device)\n",
    "            labels = labels.to(device)\n",
    "            logits = model(inputs)\n",
    "            # Accumulate the summed loss so the final mean is per sample.\n",
    "            loss_sum += F.cross_entropy(logits, labels, reduction='sum').item()\n",
    "            # Predicted class = index of the largest output.\n",
    "            guesses = logits.argmax(dim=1, keepdim=True)\n",
    "            n_correct += guesses.eq(labels.view_as(guesses)).sum().item()\n",
    "\n",
    "    mean_loss = loss_sum / n_total\n",
    "    accuracy = n_correct / n_total\n",
    "\n",
    "    print('\\nTest: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)'.format(\n",
    "        mean_loss, n_correct, n_total, 100. * n_correct / n_total))\n",
    "    return mean_loss, accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "def main(epochs=2, batch_size=64, quantize_bit=2, data_root='../data/MNIST', seed=0):\n",
    "    \"\"\"Train ``ConvNet`` on MNIST, then evaluate a k-means-quantized copy of it.\n",
    "\n",
    "    Args:\n",
    "        epochs: number of training epochs.\n",
    "        batch_size: training batch size (evaluation always uses 1000).\n",
    "        quantize_bit: bits per weight; the copy keeps ``2 ** quantize_bit`` values per layer.\n",
    "        data_root: directory holding the already-downloaded MNIST files.\n",
    "        seed: seed applied to both torch and numpy.\n",
    "\n",
    "    Returns:\n",
    "        ``(model, quant_model)``: the trained network and its quantized deep copy.\n",
    "    \"\"\"\n",
    "    torch.manual_seed(seed)\n",
    "    # KMeans falls back to numpy's global RNG when no random_state is given,\n",
    "    # so numpy must be seeded too for the quantization step to be repeatable.\n",
    "    np.random.seed(seed)\n",
    "\n",
    "    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "    # One stateless preprocessing pipeline shared by both splits.\n",
    "    # NOTE(review): (0.1307, 0.3081) are presumably the MNIST mean/std - confirm.\n",
    "    transform = transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.1307,), (0.3081,))\n",
    "    ])\n",
    "    train_loader = torch.utils.data.DataLoader(\n",
    "        datasets.MNIST(data_root, train=True, download=False, transform=transform),\n",
    "        batch_size=batch_size, shuffle=True)\n",
    "    test_loader = torch.utils.data.DataLoader(\n",
    "        datasets.MNIST(data_root, train=False, download=False, transform=transform),\n",
    "        batch_size=1000, shuffle=True)\n",
    "\n",
    "    model = ConvNet().to(device)\n",
    "    optimizer = torch.optim.Adadelta(model.parameters())\n",
    "\n",
    "    for epoch in range(1, epochs + 1):\n",
    "        train(model, device, train_loader, optimizer, epoch)\n",
    "        test(model, device, test_loader)\n",
    "\n",
    "    # Quantize a copy so the full-precision model stays available for comparison.\n",
    "    quant_model = deepcopy(model)\n",
    "    print('=='*10)\n",
    "    print('%d bits quantization' % quantize_bit)\n",
    "    quant_model.kmeans_quant(bias=False, quantize_bit=quantize_bit)\n",
    "    test(quant_model, device, test_loader)\n",
    "\n",
    "    return model, quant_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train epoch 1: 60000/60000, [-------------------------------------------------->] 100%\n",
      "Test: average loss: 0.0503, accuracy: 9827/10000 (98%)\n",
      "Train epoch 2: 60000/60000, [-------------------------------------------------->] 100%\n",
      "Test: average loss: 0.0262, accuracy: 9913/10000 (99%)\n",
      "====================\n",
      "2 bits quantization\n",
      "\n",
      "Test: average loss: 0.0609, accuracy: 9889/10000 (99%)\n"
     ]
    }
   ],
   "source": [
    "model, quant_model = main()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 可视化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_weights(model):\n",
    "    \"\"\"Plot one 50-bin weight histogram per layer of ``model`` that owns a weight.\n",
    "\n",
    "    The subplot grid is sized from the number of such layers, so models with\n",
    "    more than four weighted layers are supported. Nothing is drawn when the\n",
    "    model has no weighted layer.\n",
    "    \"\"\"\n",
    "    # Skip modules whose ``weight`` attribute is missing or set to None.\n",
    "    weighted = [(name, layer) for name, layer in model.named_modules()\n",
    "                if getattr(layer, 'weight', None) is not None]\n",
    "    if not weighted:\n",
    "        return\n",
    "\n",
    "    # Near-square grid; four layers still give the original 2x2 layout.\n",
    "    n_cols = math.ceil(math.sqrt(len(weighted)))\n",
    "    n_rows = math.ceil(len(weighted) / n_cols)\n",
    "    for slot, (name, layer) in enumerate(weighted, start=1):\n",
    "        plt.subplot(n_rows, n_cols, slot)\n",
    "        plt.hist(layer.weight.data.cpu().numpy().flatten(), bins=50)\n",
    "        plt.title(name)\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD4CAYAAAAAczaOAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAYj0lEQVR4nO3df5BdZX3H8feniUCtVRKzYATCQhtaaatBV8DqIALyy2roDNQwpQRlJh2FmTK2MyzFGVrUGpxprc44YCopwbYE0DJkEpSGAIMyRkkQ+VnIgtRsk5IggiAWDX77x3mWXjb37t7de865P57Pa+bOPfc5z57nufc+57vPfc45z1FEYGZmefi1blfAzMzq46BvZpYRB30zs4w46JuZZcRB38wsI3O7XYGpLFiwIIaHh7tdDRtgW7dufToihuou123bqjRVu+7poD88PMyWLVu6XQ0bYJL+qxvlum1blaZq1x7eMTPLiIO+mVlGHPTNzDLS02P6g2B4dMMry0+u/MDAl2vWLrfR7nBP38wsIw76ZmYZ8fCOmdWmcUjHusNB38y6zuP79fHwjplZRtzTN7NKeUint7inb2aWkUqCvqTVknZJerAh7W8k/bek+9Lj9CrKNrP+Njy6wb8OKlRVT/8a4NQm6Z+PiCXpcUtFZZuZWQuVjOlHxF2ShqvYtpnlx2f3lKfuMf0LJd2fhn/mNcsgaYWkLZK27N69u+bqWU4++tGPArxt0jDkfEkbJW1Lz/NSuiR9UdJYasNvb/ib5Sn/NknL638nZu2rM+hfCfwWsATYCfx9s0wRsSoiRiJiZGio9ntbWEbOO+88gG2TkkeBTRGxGNiUXgOcBixOjxUU7RlJ84HLgGOAo4HLWnVozHpBbadsRsRTE8uS/glYX1fZZs0cd9xxAHt4dednKXB8Wl4D3AlcnNKvjYgANkvaX9LClHdjRDwDIGkjxfGs66p/B4PNB3OrUVtPP+0gE/4YeLBVXrMuOjAidgKk5wNS+kHA9oZ84ymtVfpePHRpvaCSnr6k6yh6QAskjVP8/D1e0hIggCeBP6+ibLOKqElaTJG+d2LEKmAVwMjISNM8ZlWr6uyds5skX11FWWYle0rSwojYmX6d7krp48AhDfkOBnak9OMnpd9ZQz17nodnepOvyDV7tXXAxBk4y4GbG9LPTWfxHAs8l4Z/bgVOljQvHcA9OaWZ9STPvWPZOvvsswF+l+KMzIlhyJXADZLOB34EnJWy3wKcDowBLwIfAYiIZyR9Crgn5bt84qCuWS9y0LdsXXfddaxdu/b+iBiZtOrEyXnTWTsXNNtORKwGVldQxb7jIZ3e5+EdM7OMOOibmWXEwztm1lc8D09n3NM3M8uIg76ZWUYc9M3MMuKgb2Z9y3fZmjkHfTOzjDjom5llxEHfzCwjDvpmZhnxxVlm1hEfSO0v7umbmWXEPf2S1HFpeKsyJtLLLreq9+TL6M26xz19M7OMuKdvZn3Pvx7bV0lPX9JqSbskPdiQNl/SRknb0vO8Kso2M7PWqhreuQY4dVLaKLApIhYDm9JrMzOrUSVBPyLuAibfJ3QpsCYtrwHOqKJsMzNrrc4x/QMjYidAROyUdECzTJJWACsAFi1aVGP1us/jkmZWtZ47eyciVkXESESMDA0Ndbs6ZmYDpc6e/lOSFqZe/kJgV41lm1mJfBVu/6qzp78OWJ6WlwM311i2mZlR3Smb1wHfAX5H0rik84GVwPslbQPen16bmVmNKhneiYizW6w6sYryBlFVUytM3n6VZZh1g9v21HruQK6ZmVXHQd/MLCMO+mZmGXHQNzPLiIO+mVlGPLVyB1pdoDLdhSt1n10wXXk+28EsH+7pmzUh6UlJD0i6T9KWlNZ0enAVvihpTNL9kt7e3dqbteaevllr74uIpxteT0wPvlLSaHp9MXAasDg9jgGuTM8Dx9Mv9D/39M3a12p68KXA
tVHYDOyf5pcy6zkO+mbNBfAfkram6b5h0vTgwMT04AcB2xv+djylvYqkFZK2SNqye/fuCqtu1pqHd3pc3T+nu/XzvQcPJr87Inak+z5slPSfU+RVk7TYKyFiFbAKYGRkZK/1ZnVwT9+siYjYkZ53ATcBR5OmBweYND34OHBIw58fDOyor7Zm7XPQN5tE0m9I+s2JZeBk4EFaTw++Djg3ncVzLPDcxDCQWa/x8I7Z3g4EbpIExT7ybxHxTUn3ADekqcJ/BJyV8t8CnA6MAS8CH6m/ytZMDw4bdp2DvtkkEfEE8LYm6T+myfTgERHABTVUzaxjDvpmNiWfmz9YHPQnqfLnYJk7T6+d1dP4WXV6Axj/JLcqVH1jon7hA7lmZhmpvacv6UngeeBlYE9EjNRdBzOzXHVreGfynCZmZlYDD++YmWWkG0G/2ZwmZmZWg24M7+w1p0lE3DWxMv0jWAGwaNGiLlSvf3Xz1LpmZbdzFo5PB+xN/l4GV+09/RZzmjSuXxURIxExMjQ0VHf1zMwGWq1Bf4o5TczMrAZ1D+80ndOk5jqYWcZyv/iv1qDfak4TMzOrh6dhSKY7cFVV78AHzGbHl9SbzY6DvpkBeXZAchzq8cVZZmYZcdA3M8uIg76ZWUYc9M3MMpL1gdzcD1z1W7k5fl918OeaF/f0zcwyknVP38xsQi6nbzrom2XIQzr58vCOmVlGBrqn3+ml+q16Q+4l9Y5W38Ug/zw364R7+mZmkwyPbhjYzp2DvplZRgZ6eMfM/t+g9lyrNIhn9Linb2aWEff0zczaMCi9/r4N+q2+AP+EHSyz/T4HZQftlPcHm6xvg76ZWbf086nCtQd9SacCXwDmAF+JiJV118GsbL3Qrt2rt3bUeiBX0hzgS8BpwJHA2ZKOrLMOZmVzu7YJ/XB+f909/aOBsYh4AkDSWmAp8HDN9TArU+Xt2sew+ks731e3hoLqDvoHAdsbXo8DxzRmkLQCWJFeviDp0Sm2twB4WldMXeh060u2AHi61hLb04v1qqVO03z/h5ZQxLTtGmbctifs9RnV2J672WYGpuxW31eL9LLKbtmu6w76apIWr3oRsQpY1dbGpC0RMVJGxcrSi3WC3qxXL9ZplqZt1zCztv3Khrv4GbnswSy77ouzxoFDGl4fDOyouQ5mZXO7tr5Rd9C/B1gs6TBJ+wDLgHU118GsbG7X1jdqHd6JiD2SLgRupTi1bXVEPNTBJmf0U7kmvVgn6M169WKdZqyCdt2om5+Ryx7AshWx19CjmZkNKE+4ZmaWEQd9M7OM9HzQlzRf0kZJ29LzvBb5XpZ0X3qsa0g/TNJ3099fnw60VV4nSUskfUfSQ5Lul/ThhnXXSPphQ32XdFCXUyU9KmlM0miT9fum9z2WPofhhnWXpPRHJZ0y2zrMsl6fkPRw+mw2STq0YV3T7zIHnbatqstO+b4p6VlJ60soc9btt+Jyj5N0r6Q9ks4so8wZlN1y3yhFRPT0A/gcMJqWR4ErWuR7oUX6DcCytHwV8LE66gQcASxOy28GdgL7p9fXAGeWUI85wOPA4cA+wA+AIyfl+ThwVVpeBlyflo9M+fcFDkvbmVPSd9ZOvd4HvDYtf2yiXlN9lzk8Om1bVZed1p0IfBBYX0M7adp+ayh3GHgrcG0Z++oMy265b5Tx6PmePsXl7GvS8hrgjHb/UJKAE4CvzebvO6lTRDwWEdvS8g5gFzBUQtmNXrn8PyJ+AUxc/t+qrl8DTkyfy1JgbUS8FBE/BMbS9mqpV0TcEREvppebKc5tt+62rbb2tYjYBDxfQnmdtN9Ky42IJyPifuBXHZY1m7Ir3Tf6IegfGBE7AdLzAS3y7Sdpi6TNkiYa6xuBZyNiT3o9TnHJfF11AkDS0RT/1R9vSP5M+vn2eUn7zrIezS7/n/z+XsmTPofnKD6Xdv52tma67fOBbzS8bvZd5qKMtlVL2SXopP1WXW5VOt03OtYT8+lLug14U5NVl85gM4si
Yoekw4HbJT0A/LRJvrbOUS2pTkhaCHwVWB4RE72GS4D/odhZVwEXA5fPZLsTm2+SNvn9tcrT1tQBs9T2tiWdA4wA721I3uu7jIgyglpPqLht1VJ2STppv1WXW5VO942O9UTQj4iTWq2T9JSkhRGxMzXyXS22sSM9PyHpTuAo4OvA/pLmpl5C25fHl1EnSa8HNgCfjIjNDdvemRZfkvTPwF+1U6cm2rn8fyLPuKS5wBuAZ9r829lqa9uSTqIINu+NiJcm0lt8lwMT9KtsW3WUXaJO2m/V5Valo32jDP0wvLMOWJ6WlwM3T84gad7EEImkBcC7gYejOBJyB3DmVH9fUZ32AW4Cro2IGyetW5ieRTFu+uAs69HO5f+NdT0TuD19LuuAZensiMOAxcD3ZlmPGddL0lHAl4EPRcSuhvSm32VJ9eoHHbWtqssuWSftt+pyqzLrfaM0ZR4VruJBMX63CdiWnuen9BGKOxQB/CHwAMWR8AeA8xv+/nCKYDYG3AjsW1OdzgF+CdzX8FiS1t2e6vkg8C/A6zqoy+nAYxQ94UtT2uUUDQZgv/S+x9LncHjD316a/u5R4LSSv7fp6nUb8FTDZ7Nuuu8yh0enbavqstPrbwG7gZ9T9FxP6Ub7rbh9vjO9t58BPwYeKvE7ntW+UdbD0zCYmWWkH4Z3zMysJA76ZmYZmTboS9pP0vck/SBd9v23Kf0wNZneYKrLplXhZf9msyFpjqTvK00p4HZtg66dUzZfAk6IiBckvQb4tqRvAJ8APh8RayVdRXERwZXp+ScR8duSlgFXAB+WdCTFkerfo7h0/DZJR0TEy60KXrBgQQwPD3fy/sym8zOKsyden15fQcXtGty2rVpbt259OiKaX6U9w6POrwXupbjp89PA3JT+LuDWtHwr8K60PDflE8UFSZc0bOuVfK0e73jHO8KsKtu3bw+KC/hOANandlp5uw63basYsCVatL22xvTTT+D7KC7W2EhxqlGr6Q06uuxf0op0Cf6W3bt3t1M9s1m56KKLoGiHE1ezTjVtRzemszArXVtBPyJejoglFFePHQ28pVm29NzRZf8RsSoiRiJiZGio7PnJzArr16/ngAMOAHixIXmqNtrxdBbu0FgvmNHZOxHxLHAncCxpeoO0qvFS4lcuM67xsn+zGbn77rtZt24dwB9QzHR4AvCPVNiu3aGxXtDO2TtDkvZPy78OnAQ8QuvpDbpx2X/Whkc3MDy6oWna5HQrfPazn2V8fByKq36XUbTTP8XteuDlvm+0c/bOQmCNpDkU/yRuiIj1kh4G1kr6NPB94OqU/2rgq5LGKHpCywAi4iFJN1DMo7IHuCCmOcPBWmvWYHNtxCW7GLfrgeT9ozBt0I/iRgJHNUl/giY33YiI/wXOarGtzwCfmXk1zaoTEXdSDFu6XdvA64mpla097qmYWac8DYOZWUYc9M3MMuKgb2aWEQd9M7OMOOibmWXEQd/MLCMO+mZmGfF5+maWrcZrX55c+YEu1qQ+7umbmWXEPf0M5NibMbPm3NM3M8uIg76ZWUY8vNPjPMmamZXJPX0zs4w46JuZZcRBPzM53ybOzBz0zcyy4qBvZpYRn71jZgPLQ5l7c0/fzCwjDvpmZhlx0Dczy8i0QV/SIZLukPSIpIck/UVKny9po6Rt6XleSpekL0oak3S/pLc3bGt5yr9N0vLq3paZmTXTTk9/D/CXEfEW4FjgAklHAqPApohYDGxKrwFOAxanxwrgSij+SQCXAccARwOXTfyjMDOzekwb9CNiZ0Tcm5afBx4BDgKWAmtStjXAGWl5KXBtFDYD+0taCJwCbIyIZyLiJ8BG4NRS342ZmU1pRqdsShoGjgK+CxwYETuh+Mcg6YCU7SBge8Ofjae0VumTy1hB8QuBRYsWzaR6NgOeY98sT20fyJX0OuDrwEUR8dOpsjZJiynSX50QsSoiRiJiZGhoqN3qmZlZG9oK+pJeQxHw/zUi/j0lP5WGbUjPu1L6
OHBIw58fDOyYIt3MzGrSztk7Aq4GHomIf2hYtQ6YOANnOXBzQ/q56SyeY4Hn0jDQrcDJkualA7gnpzQzM6tJO2P67wb+DHhA0n0p7a+BlcANks4HfgScldbdApwOjAEvAh8BiIhnJH0KuCfluzwininlXZiZWVumDfoR8W2aj8cDnNgkfwAXtNjWamD1TCpoZmbl8YRrZmbkc0abp2EwM8uIe/o9ylPCms2O952puadvZpYRB30zs4w46JuZZcRB38wsIw76ZmYZcdA3hkc3vPLIyfbt2wGO8A2CLCcO+patuXPnAoz7BkGWEwd9y9bChQuhmB/KNwiybDjomzH1DYKA0m4QJGmLpC27d+8u+y2YtcVX5PaQ3MbUe8XkGwQVs4k3z9okbUY3CAJWAYyMjOy13qwO7ulb7oRvEGQZcU/fslXMAs6hwLda3CBoJXvfIOhCSWspDto+l+4PfSvwdw0Hb08GLqnhLVjiX8ntc9C3bN19990AbwRO8A2CLBcO+pat97znPQBbI2KkyWrfIMgGksf0zcwy4qBvZpYRB30zs4w46JuZZcQHcs2sb1V1qubEdgfxBunu6ZuZZWTanr6k1cAfAbsi4vdT2nzgemAYeBL4k4j4iYrr179AcS7zi8B5EXFv+pvlwCfTZj8dEWswX1RiZrVqp6d/DXvPGOipZwdUrnPrm+Vi2qAfEXcBk68u9NSzZmZ9aLZj+pVMPQueftbMrEplH8jtaOpZKKafjYiRiBgZGhoqtXJmZrmbbdD31LNmZn1otkF/YupZ2Hvq2XPTDaSPJU09C9wKnCxpXjqAe3JKMzOzGrVzyuZ1wPHAAknjFGfheOpZM7M+NG3Qj4izW6zy1LOz5NMhzaxbPA2DmfUVd5o646BvLTXuXIM4B4lZjhz0zcxaGMSOjydcMzPLiIO+mVlGHPStLZ6EzWwweEzfzHqeOxzlcdCvkRuumXWbh3fMzDLinr7NyCCewmaWE/f0zcwy4p6+mfWkXjsGNii/ch30K9ZrDdfM8ubhHTOzjDjom5llxMM7NmuDMsZpvcVDotVy0K+AG63ZYOvnDo+Hd8zMMuKevpWin3s+1n3+dVwf9/TNzDLinn5J3FMxy1O//cp10O+AA31zE59LP+wAVi/vM93noG+V6bcekFmn+qHD46A/C+6tzJz/AeTL+0tvqT3oSzoV+AIwB/hKRKysuw5mZXO73lvOwb6XOzm1Bn1Jc4AvAe8HxoF7JK2LiIfrrEe7cm60VWr2ufbajjET/dauy+b9ZGq99g+g7p7+0cBYRDwBIGktsBToys7hxto7Zvpd9MLO06Cn2nUZvG9UY7rPtY52XXfQPwjY3vB6HDimMYOkFcCK9PIFSY+WVPYC4OmSttUP5Xaz7MrL1RWllX1ox5Vpo11DpW27at1sw1Xrqfc2RbueqZbtuu6gryZp8aoXEauAVaUXLG2JiJGyt9ur5Xaz7Azf87TtGqpr21Xr5vdZtUF+b63UfUXuOHBIw+uDgR0118GsbG7X1jfqDvr3AIslHSZpH2AZsK7mOpiVze3a+katwzsRsUfShcCtFKe2rY6Ih2oqvls/q7v5c97vuQZdbtd16LshqRkY5PfWlCL2Gno0M7MB5Vk2zcwy4qBvZpaRgQ36kuZL2ihpW3qe1yTPEknfkfSQpPslfbiOclO+b0p6VtL6Eso8VdKjksYkjTZZv6+k69P670oa7rTMNss9TtK9kvZIOrOMMtss9xOSHk7f6SZJZZyLn41u7TtV6tY+0pMiYiAfwOeA0bQ8ClzRJM8RwOK0/GZgJ7B/1eWmdScCHwTWd1jeHOBx4HBgH+AHwJGT8nwcuCotLwOuL+HzbafcYeCtwLXAmSV9r+2U+z7gtWn5Y2W835we3dp3Knw/XdlHevUxsD19isvg16TlNcAZkzNExGMRsS0t7wB2AUNVl5vK2wQ832FZ0DAFQET8ApiYAqBVnb4GnCip2QVFpZYbEU9GxP3Arzosa6bl
3hERL6aXmynOm7f2dWvfqUq39pGeNMhB/8CI2AmQng+YKrOkoyl6AY/XWW4Jmk0BcFCrPBGxB3gOeGMN5VZhpuWeD3yj0hoNnm7tO1Xp1j7Sk/p6Pn1JtwFvarLq0hluZyHwVWB5REzbKy2r3JK0MwVAW9MEVFBuFdouV9I5wAjw3kpr1Ie6te90Sbf2kZ7U10E/Ik5qtU7SU5IWRsTO1DB3tcj3emAD8MmI2FxXuSVqZwqAiTzjkuYCbwCeqaHcKrRVrqSTKALYeyPipRrq1Ve6te90Sbf2kZ40yMM764DlaXk5cPPkDOmS+ZuAayPixrrKLVk7UwA01ulM4PZIR6wqLrcK05Yr6Sjgy8CHIqLqf7qDqFv7TlW6tY/0pm4fSa7qQTEetwnYlp7np/QRijsbAZwD/BK4r+GxpOpy0+tvAbuBn1P0Mk7poMzTgccoxlQvTWmXUwQ9gP2AG4Ex4HvA4SV9xtOV+8703n4G/Bh4qKZybwOeavhO13W7PfbTo1v7TsXvqSv7SC8+PA2DmVlGBnl4x8zMJnHQNzPLiIO+mVlGHPTNzDLioG9mlhEHfTOzjDjom5ll5P8AIIQzKrV2e+4AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 4 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plot_weights(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAY4AAAD4CAYAAAD7CAEUAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAb0klEQVR4nO3de5Bc5Xnn8e9vJYPXLDYSGgktQh4pkcnKbFnGU8CWHWJbtiQgRrgCXqkcGC6pcRGoTbz5w+OwKS6O17IrXjaUHVhh5BWpmItxCCpBpMgyymapcBmBwAKMNcgYBk3pYnErw0JEnv3jvAOtUY+m35nuc3pGv09VV5/znPecfk73e+bpc+kzigjMzMwa9W+qTsDMzCYWFw4zM8viwmFmZllcOMzMLIsLh5mZZZladQJjNWPGjOjs7Kw6DZuktm7dui8iOsp+Xfdra6Vm9esJWzg6Ozvp6+urOg2bpCT9sorXdb+2VmpWv/ahKjMzy+LCYWZmWVw4zMwsy4Q9xzFZdfbeC8Bzq86pOBOz5hnq1+C+PRl4j8PMzLK4cJiZWZZJd6jKu8RmZq3lPQ4zM8viwmFW3xRJd0n6maSnJf0nSdMlbZK0Iz1PA1DhBkn9kp6QdOrQQiR1p/Y7JHVXtzpmzePCYVbfScCGiPgt4CPA00AvsDkiFgCb0zjAWcCC9OgBbgSQNB24GjgdOA24eqjYmE1kLhxmw7z66qsAxwK3AETEWxHxMrAcWJuarQXOS8PLgVuj8CBwnKTZwFJgU0Tsj4iXgE3AsvLWxHJ19t570HlSq8+Fw2yYnTt3AhwAvi/pMUnfk3QMMCsiBgHS88w0y4nACzWLGEixkeIHkdQjqU9S3969e5u+PmbN5sJhNsyBAwcA3gfcGBEfBX7Nu4el6lGdWBwmfnAgYnVEdEVEV0dH6TfkNctWauGQdLKkbTWPVyX9saRrJL1YEz+7zLzMas2ZMwfgrYh4KIXuAk4FdqdDUKTnPWn6AMU5kXcWAew6TNxsQiu1cETEMxGxKCIWAR8DXgfuTpOvH5oWEfeVmZdZrRNOOAHgLUknp9Bi4ClgHTB0ZVQ3cE8aXgdclK6uOgN4JR3K2ggskTQtnRRfkmJmE1qVPwBcDDwbEb+U6u3Rm1XqeeBvJB0F7AQuofiidaeky9L0C1Lb+4CzgX6KL0OXAETEfklfAx5J7a6LiP3lrYJZa1RZOFYAt9WMXynpIqAP+JN0FcpBJPVQXO7I3LlzS0nSjlhvRERXnfji4YGICOCKeguJiDXAmibnZlapSk6Op29x5wI/TKEbgd8AFgGDwLfrzeeTiGZm1avqqqqzgEcjYjdAROyOiLcj4l+Bmyl+LGVmZm2oqsKxkprDVENXqiSfB7aXnpGZmTWk9HMckt4HfBb4Uk34W5IWUVzj/tywaWZm1kZKLxwR8Tpw/LDYhWXnYWZmY+NfjpuZWRYXDjMzy+LCYWZmWVw4zMwsiwuHmZllceEwM7MsVd6ryiaB2v+W9tyqcyrMxMzK4j0OMzPL4sJhZmZZXDjMzCyLC4eZmWVx4TAzsywuHGZmlsWFw8zMsrhwmJlZFhcOsxFImiLpMUnr0/g8SQ9J2iHpDklHpfjRabw/Te+sWcZXU/wZSUurWROz5nLhMBvZHwFP14x/E7g+IhYALwGXpfhlwEsR8ZvA9akdkhYCK4APA8uAv5I0paTczVrGhcOsvvcA5wDfA5Ak4NPAXWn6WuC8NLw8jZOmL07tlwO3R8SbEfELoB84rZz0zVqniv85/hzwGvA2cCAiuiRNB+4AOin+5/gXIuKlsnOzyS/j3lonAZcCx6bx44GXI+JAGh8ATkzDJwIvAETEAUmvpPYnAg/WLLN2nndI6gF6AObOnZu1PmZVqGqP41MRsSgiutJ4L7A5HQLYnMbNKrF+/XoovtRsrQmrTtMYZdrh5nk3ELE6IroioqujoyM3XbPS
tcvdcZcDn0zDa4EtwFeqSsaObA888ADAcWnv+L3A+4H/mWJT017HHGBXmmWAYg9lQNJU4APA/pr4kNp5zCasKvY4AvgHSVvTLjrArIgYBEjPM+vNKKlHUp+kvr1795aUrh1pvvGNbwA8ERGdFCe3fxIRXwTuB85PzbqBe9LwujROmv6TiIgUX5GuupoHLAAeLmUlzFqoij2Oj0fELkkzgU2SftbojBGxGlgN0NXVdcguv1mLfQW4XdKfA48Bt6T4LcBfS+qn2NNYARART0q6E3gKOABcERFvl5+2WXOVXjgiYld63iPpboqrTHZLmh0Rg5JmA3vKzsusnojYQnHolIjYSZ2roiLi/wEXjDD/14Gvty5Ds/KVeqhK0jGSjh0aBpYA2zl4V7/2EICZmbWZsvc4ZgF3F5e4MxX4QURskPQIcKeky4DnGeHbm5mZVa/UwpF29T9SJ/4rYHGZuZiZ2dj4l+NmZpbFhcPMzLK4cJiZWZZ2+eW4mY0i4z5bZi3lPQ4zM8viwmFmZllcOMzMLIsLh5mZZXHhMDOzLC4cZmaWxYXDzGwS6Oy996BLtlvJhcPMzLK4cJiZWRYXDjMzy+LCYWZmWVw4zMwsiwuH2TAvvPACwIckPS3pSUl/BCBpuqRNknak52kpLkk3SOqX9ISkU4eWJak7td8hqbv+K5pNLC4cZsNMnToVYCAi/gNwBnCFpIVAL7A5IhYAm9M4wFnAgvToAW6EotAAVwOnA6cBVw8VG7OJrNTCIekkSffX+SZ3jaQXJW1Lj7PLzMus1uzZswFeB4iI14CngROB5cDa1GwtcF4aXg7cGoUHgeMkzQaWApsiYn9EvARsApaVtiJmLVL2/+M4APxJRDwq6Vhgq6RNadr1EfEXJedjdliSOoGPAg8BsyJiECAiBiXNTM1OBF6omW0gxUaKD3+NHoo9FebOndvcFTBrgVL3OCJiMCIeTcO13+TM2o6kfwf8CPjjiHj1cE3rxOIw8YMDEasjoisiujo6OsaWrFmJKjvHMeybHMCV6cTimpGOA0vqkdQnqW/v3r0lZWpHKFEUjb+JiL9Nsd3pEBTpeU+KDwAn1cw7B9h1mLjZhFZJ4ajzTe5G4DeARcAg8O168/mbmZUhIgA+CDwdEf+jZtI6YOjKqG7gnpr4RenqqjOAV9IhrY3AEknT0pehJSlmNqGV/j/HJb2HYd/kImJ3zfSbgfVl52U25IEHHgA4Hvi0pG0p/KfAKuBOSZcBzwMXpGn3AWcD/RQn1S8BiIj9kr4GPJLaXRcR+0tZCbMWKrVwSBJwC8O+yUmaPXTSEfg8sL3MvMxqfeITnwDYGhFddSYvHh6IYhflinrLiog1wJqmJmhWsbL3OD4OXAj8dNg3uZWSFlGcOHwO+FLJeZmZWYNKLRwR8X+pf6XJfWXmYWZmY+dfjpuZWRYXDjMzy+LCYWZmWVw4zMwsiwuHmZllceEwM7MsLhxmZpbFhcPMzLK4cJiZWRYXDjMzy+LCYWZmWVw4zMwsiwuHmZllceEwM7MsLhxmZpbFhcPMzLK4cJiZWRYXDjMzy9I2hUPSMknPSOqX1Ft1PmbN4r5tk01bFA5JU4DvAmcBC4GVkhZWm5XZ+Llv22TUFoUDOA3oj4idEfEWcDuwvOKczJrBfdsmHUVE1Tkg6XxgWUT8QRq/EDg9Iq4c1q4H6EmjJwPPtCilGcC+Fi27bF6XsflgRHSMdyGN9O0R+nU7f27tnBs4v8NpSr+e2oxMmkB1YodUtIhYDaxueTJSX0R0tfp1yuB1qdyofbtev27ndW3n3MD5laFdDlUNACfVjM8BdlWUi1kzuW/bpNMuheMRYIGkeZKOAlYA6yrOyawZ3Ldt0mmLQ1URcUDSlcBGYAqwJiKerDCllh8OK5HXpULj6NvtvK7tnBs4v5Zri5PjZmY2cbTLoSozM5sgXDjMzCyLCwcgabqkTZJ2pOdpddoskvTPkp6U9ISk/1xFriMZ7bYWko6WdEea
/pCkzvKzbEwD6/JfJT2VPofNkj5YRZ7j0UifS+02SHpZ0vph8Xnpc9yRPtejKsqvO7XZIam7Jr4lfYbb0mNmk/Iacz+X9NUUf0bS0mbk04zcJHVKeqPmvbqp2bk1XUQc8Q/gW0BvGu4FvlmnzYeABWn43wODwHFV557ymQI8C8wHjgIeBxYOa/OHwE1peAVwR9V5j2NdPgW8Lw1f3q7rMt4+l6YtBj4HrB8WvxNYkYZvAi4vOz9gOrAzPU9Lw9PStC1AVwV9o24/p7jdy+PA0cC8tJwpbZJbJ7C96j6Z8/AeR2E5sDYNrwXOG94gIn4eETvS8C5gDzDuX2A2SSO3tahdx7uAxZLq/TitaqOuS0TcHxGvp9EHKX4bMdGM2ucAImIz8FptLH1un6b4HA87f4vzWwpsioj9EfESsAlY1uQ8ao2nny8Hbo+INyPiF0B/Wl475DbhuHAUZkXEIEB6PuxutaTTKL5VPFtCbo04EXihZnwgxeq2iYgDwCvA8aVkl6eRdal1GfD3Lc2oNbL63DDHAy+nzxFGf49ald9on9X306GXP2vSH8jx9PPcflVmbgDzJD0m6R8l/XYT82qJtvgdRxkk/Rg4oc6kqzKXMxv4a6A7Iv61Gbk1QSO3bGnoti5toOE8Jf0+0AX8TkszGqNm9bl6i64Ty/4sm5Df4fL4YkS8KOlY4EfAhcCtuTlmvN5obVrd/8eT2yAwNyJ+JeljwN9J+nBEvNrE/JrqiCkcEfGZkaZJ2i1pdkQMpsKwZ4R27wfuBf5bRDzYolTHopHbWgy1GZA0FfgAsL+c9LI0dIsOSZ+h+AP3OxHxZkm5ZWlGnxvBPuA4SVPTN9cx3cakCfkNAJ+sGZ9DcW6DiHgxPb8m6QcUh3LGWzjG089bfeuXMecWxYmONwEiYqukZynOqfY1Mb+m8qGqwjpg6IqQbuCe4Q3SVSt3A7dGxA9LzK0RjdzWonYdzwd+kjpsuxl1XSR9FPhfwLkRkfMHt52M2udGkj63+yk+x+z5G9RIfhuBJZKmpauulgAbJU2VNANA0nuA3wW2NyGn8fTzdcCKdGXTPGAB8HATchp3bpI6VPzfFiTNT7ntbGJuzVf12fl2eFAcZ9wM7EjP01O8C/heGv594F+AbTWPRVXnXrMOZwM/pzjvclWKXUfxxxXgvcAPKU4KPgzMrzrncazLj4HdNZ/DuqpzbkWfS+P/BOwF3qD4xro0xeenz7E/fa5HV5TfpSmHfuCSFDsG2Ao8ATwJ/CVNuoJpPP2cYg/1WYrb1p9VQb+tmxvwe+l9ehx4FPhc1f1ztIdvOWJmZll8qMrMzLK4cJiZWRYXDjMzyzJhL8edMWNGdHZ2Vp2GTVJbt27dF03438y53K+tlZrVryds4ejs7KSvr20vc7YJ4NJLL2X9+vXMnDmT7duLq0WvueYabr75ZoBjJG0D/jQi7oPiJnkUv1R/G/gvEbExxZeRrhyiuOJoVYrPo7j1xHSKq2UujOJ2FCNyv7ZWkvTLZizHh6rsiHXxxRezYcOGQ+Jf/vKXAZ6KiEU1RWMhxbX5H6a4H9NfSZqSrr//LnAWxY30Vqa2AN8Ero+IBcBLFEXHbMJz4bAj1plnnsn06dMbbT7STfLq3tyupBsRmlXChcNsmO985zsACyWtqfk/FCPdxG6keMM3IpTUI6lPUt/evXubtyJmLeLC0WY6e++ls/feqtM4Yl1++eU8++yzAE9R3Hzu22lS7s3zGr6pXkSsjoiuiOjq6GiXO/XbkKFt0tvluybsyXGzVpg1a1bt6M3A0H/eO9xN7OrFm3IjQrN25D0OsxqDg4O1o5/n3ZvzjXSTvLo3t4viXj6tvhGhWSW8x2FHrJUrV7Jlyxb27dvHnDlzuPbaa9myZQvbtm2D4gqpTwFfAoiIJyXdSXEI6wBwRUS8DSDpSoo7xU4B1kTEk+klvgLcLunPgceAW8pcP7NWceGwI9Ztt912SOyyy4orZiU9FRHn1k6LiK8D
Xx8+T7pk97468Z0099+TmrUFFw5rSO2JwedWnVNhJmZWNZ/jMDOzLC4cZmaWxYXDzMyyuHCYmVkWFw4zM8viwmFmZllcOMzMLIsLh5mZZXHhMDOzLC4cZmaWxYXDzMyyuHCYmVkWFw4zM8viwmFmZllcOMzMLIsLh5mZZXHhMDOzLKMWDklrJO2RtL0mNl3SJkk70vO0FJekGyT1S3pC0qk183Sn9jskddfEPybpp2meGySp2StpZmbN08gex/8Glg2L9QKbI2IBsDmNA5wFLEiPHuBGKAoNcDVwOsX/YL56qNikNj018w1/LTMzayOjFo6I+D/A/mHh5cDaNLwWOK8mfmsUHgSOkzQbWApsioj9EfESsAlYlqa9PyL+OSICuLVmWWZm1obGeo5jVkQMAqTnmSl+IvBCTbuBFDtcfKBOvC5JPZL6JPXt3bt3jKmbmdl4NPvkeL3zEzGGeF0RsToiuiKiq6OjY4wpmhUuvfRSZs6cySmnnPJObP/+/Xz2s58FOMXn78zqG2vh2J0OM5Ge96T4AHBSTbs5wK5R4nPqxM1a7uKLL2bDhg0HxVatWsXixYsBtuPzd2Z1jbVwrAOGvll1A/fUxC9K387OAF5Jh7I2AkskTUsb1RJgY5r2mqQz0rexi2qWZdZSZ555JtOnTz8ods8999Dd/c5Og8/fmdUxdbQGkm4DPgnMkDRA8e1qFXCnpMuA54ELUvP7gLOBfuB14BKAiNgv6WvAI6nddRExdML9coort/4t8PfpYVaJ3bt3M3v2bKA4fyep5efvJPVQ7Jkwd+7c8a+EWYuNWjgiYuUIkxbXaRvAFSMsZw2wpk68Dzjl0DnM2krLzt9FxGpgNUBXV9eI5/jM2oV/OW5WY9asWQwODgI+f2c2EhcOsxrnnnsua9cO/UTJ5+/M6hn1UJXZZLVy5Uq2bNnCvn37mDNnDtdeey29vb184QtfgOLw6Sv4/J3ZIVw47Ih122231Y1v3rwZSdsj4p3zeD5/Z/YuH6oyM7Ms3uMws8p19t77zvBzq86pMBNrhPc4zMwsiwuHmZllceEwM7MsLhxmZpbFhcPMzLK4cJiZWRYXDjMzy+LCYWZmWVw4zMwsiwuHmZllceEwM7MsLhxmZpbFhcPMzLK4cJiZWRYXDjMzy+LCYWZmWVw4zMwsiwuHmZllceEwM7Ms4yockp6T9FNJ2yT1pdh0SZsk7UjP01Jckm6Q1C/pCUmn1iynO7XfIal7PDl19t77zsPMzJqvGXscn4qIRRHRlcZ7gc0RsQDYnMYBzgIWpEcPcCMUhQa4GjgdOA24eqjYmJlZ+2nFoarlwNo0vBY4ryZ+axQeBI6TNBtYCmyKiP0R8RKwCVjWgrzMcvzHdtubNmsX4y0cAfyDpK2SelJsVkQMAqTnmSl+IvBCzbwDKTZS/BCSeiT1Serbu3fvOFM3G5X3ps3qGG/h+HhEnEqx4Vwh6czDtFWdWBwmfmgwYnVEdEVEV0dHR362ZuPjvWkzxlk4ImJXet4D3E3xrWp32mhIz3tS8wHgpJrZ5wC7DhM3q1ope9Pek7aJZsyFQ9Ixko4dGgaWANuBdcDQsdxu4J40vA64KB0PPgN4JW18G4Elkqal3fglKWZWpZ+VtTftPWmbaKaOY95ZwN2Shpbzg4jYIOkR4E5JlwHPAxek9vcBZwP9wOvAJQARsV/S14BHUrvrImL/OPIya4Z/gWJvWtJBe9MRMZixN/3JYfEtLc7bjlBDP0F4btU5LX+tMReOiNgJfKRO/FfA4jrxAK4YYVlrgDVjzcWsUbW/7xlpA/v1r38NaW+8Zm/6Ot7dm17FoXvTV0q6neJE+CupuGwE/nvNCfElwFebu0Zm5RvPHofZpLR7926A35L0ON6bNjuEC4fZMPPnzwd4quYyXMB702ZDfK8qMzPL4sJhZmZZXDjMzCyLC4eZmWVx4TAzsywuHGZmlsWFw8zMsrhwmJlZFv8A0GyCaOR2KWZl8B6HmZll
ceEwM7MsLhxmZpbFhcPMzLK4cJiZWRYXDjMzy+LCYWZmWVw4zMwsiwuHmZllceEwM7MsLhxmZpbFhcPMzLK4cJiZWRYXDjMzy9I2hUPSMknPSOqX1Ft1PmbN4r5tk01bFA5JU4DvAmcBC4GVkhZWm5XZ+Llv22TUFoUDOA3oj4idEfEWcDuwvOKczJrBfdsmHUVE1Tkg6XxgWUT8QRq/EDg9Iq4c1q4H6EmjJwPPlJro6GYA+6pOIoPzHdkHI6JjvAtppG83qV9PtM9yiPMu18kRcex4F9Iu/zpWdWKHVLSIWA2sbn06YyOpLyK6qs6jUc63FKP27Wb06wn63jjvkknqa8Zy2uVQ1QBwUs34HGBXRbmYNZP7tk067VI4HgEWSJon6ShgBbCu4pzMmsF92yadtjhUFREHJF0JbASmAGsi4smK0xqLtj2MNgLn22Il9u0J994kzrtcTcm7LU6Om5nZxNEuh6rMzGyCcOEwM7MsLhwZJE2XtEnSjvQ8bYR23anNDkndNfEt6dYT29JjZovyPOwtLiQdLemONP0hSZ01076a4s9IWtqK/JqVr6ROSW/UvJ83lZFvFTL63gZJL0taPyw+L713O9J7eVQ5mU+c7abm9SbU9jPevMe0HUWEHw0+gG8BvWm4F/hmnTbTgZ3peVoanpambQG6WpzjFOBZYD5wFPA4sHBYmz8EbkrDK4A70vDC1P5oYF5azpQ2zrcT2F51v2iXvpemLQY+B6wfFr8TWJGGbwIub6fcq95uavKYUNtPk/LO3o68x5FnObA2Da8FzqvTZimwKSL2R8RLwCZgWUn5QWO3uKhdj7uAxZKU4rdHxJsR8QugPy2vXfM9kjTS94iIzcBrtbH0Xn2a4r077PwtMhG2myETbftpRt7ZXDjyzIqIQYD0XG+X+UTghZrxgRQb8v20O/hnLfrjN9rrH9QmIg4ArwDHNzhvs40nX4B5kh6T9I+SfrvFuVapkb43kuOBl9N7B+V8rrUmwnbTaB4HtWmD7eeQnA7z2k3bjtridxztRNKPgRPqTLqq0UXUiQ1d8/zFiHhR0rHAj4ALgVvzsxzz64/WpqFbvzTZePIdBOZGxK8kfQz4O0kfjohXm51kGZrQ90ZcdJ1YUz/XSbDdNJLHaG2q2H6GlLoduXAMExGfGWmapN2SZkfEoKTZwJ46zQaAT9aMz6E4RktEvJieX5P0A4rdy2ZvAI3c4mKozYCkqcAHgP0NzttsY843igO0bwJExFZJzwIfAppyP56yNaHvjWQfcJykqembZtM/10mw3dTmMZG2n+E5He61m7Yd+VBVnnXA0NUe3cA9ddpsBJZImpauHlkCbJQ0VdIMAEnvAX4X2N6CHBu5xUXtepwP/CR1nnXAinT1xTxgAfBwC3JsSr6SOlT8vwskzU/57mxxvlVppO/VlT7b+yneu+z5m2AibDdDJtr2M+68x7QdlXHGf7I8KI4HbgZ2pOfpKd4FfK+m3aUUJ8b6gUtS7BhgK/AE8CTwl7ToigvgbODnFFdZXJVi1wHnpuH3Aj9M+T0MzK+Z96o03zPAWSW9r2PKF/i99F4+DjwKfK7qPtIGfe+fgL3AGxTfMJem+Pz03vWn9/LoNsy90u1mvP0xTSt9+xlv3mPZjnzLETMzy+JDVWZmlsWFw8zMsrhwmJlZFhcOMzPL4sJhZmZZXDjMzCyLC4eZmWX5/1E/bFKr0vRVAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 4 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plot_weights(quant_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
